{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071c7845",
   "metadata": {},
   "source": [
    "下面介绍基于循环神经网络的编码器和解码器的代码实现。首先是作为编码器的循环神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f682c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
     "\"\"\"\n",
     "Adapted from the GitHub project pytorch/tutorials\n",
     "(Copyright (c) 2023, PyTorch, BSD-3-Clause License; see appendix).\n",
     "\"\"\"\n",
     "import torch\n",
     "import torch.nn as nn\n",
     "\n",
     "class RNNEncoder(nn.Module):\n",
     "    \"\"\"GRU-based encoder: maps a token-id sequence to per-step\n",
     "    hidden states plus a final hidden state.\"\"\"\n",
     "    def __init__(self, vocab_size, hidden_size):\n",
     "        super(RNNEncoder, self).__init__()\n",
     "        # Hidden-state size of the GRU (also used as the embedding size).\n",
     "        self.hidden_size = hidden_size\n",
     "        # Source-side vocabulary size.\n",
     "        self.vocab_size = vocab_size\n",
     "        # Token embedding layer.\n",
     "        self.embedding = nn.Embedding(self.vocab_size,\\\n",
     "            self.hidden_size)\n",
     "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
     "            batch_first=True)\n",
     "\n",
     "    def forward(self, inputs):\n",
     "        # inputs: batch * seq_len\n",
     "        # The GRU is built with batch_first=True, so the input must carry\n",
     "        # a leading batch dimension (batch size of at least 1).\n",
     "        features = self.embedding(inputs)\n",
     "        # output: per-step hidden states; hidden: final hidden state.\n",
     "        output, hidden = self.gru(features)\n",
     "        return output, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dc8af74",
   "metadata": {},
   "source": [
    "接下来是作为解码器的另一个循环神经网络的代码实现。<!--我们使用编码器最终的输出用作解码器的初始隐状态，这个输出向量有时称作上下文向量（context vector），它编码了整个源序列的信息。解码器最初的输入词元是“\\<sos\\>”（start-of-string）。解码器的目标为，输入编码器隐状态，一步一步解码出整个目标序列。解码器每一步的输入可以是真实目标序列，也可以是解码器上一步的预测结果。-->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9beee561",
   "metadata": {},
   "outputs": [],
   "source": [
     "class RNNDecoder(nn.Module):\n",
     "    \"\"\"GRU decoder without attention.\n",
     "\n",
     "    Decodes starting from the encoder's final hidden state, using either\n",
     "    teacher forcing (when a target sequence is given) or greedy decoding.\n",
     "    \"\"\"\n",
     "    def __init__(self, vocab_size, hidden_size):\n",
     "        super(RNNDecoder, self).__init__()\n",
     "        self.hidden_size = hidden_size\n",
     "        self.vocab_size = vocab_size\n",
     "        # Sequence-to-sequence tasks do not require the encoder and the\n",
     "        # decoder to use the same language, so the decoder needs its own\n",
     "        # embedding layer.\n",
     "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
     "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
     "            batch_first=True)\n",
     "        # Maps an output hidden state to a distribution over the vocabulary.\n",
     "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
     "\n",
     "    # Decode an entire sequence.\n",
     "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
     "        batch_size = encoder_outputs.size(0)\n",
     "        # Start decoding from <sos> (SOS_token is a notebook-level constant).\n",
     "        decoder_input = torch.empty(batch_size, 1,\\\n",
     "            dtype=torch.long).fill_(SOS_token)\n",
     "        decoder_hidden = encoder_hidden\n",
     "        decoder_outputs = []\n",
     "        \n",
     "        # If the target sequence is known, the number of decoding steps is\n",
     "        # fixed by its length; otherwise decode up to MAX_LENGTH steps.\n",
     "        if target_tensor is not None:\n",
     "            seq_length = target_tensor.size(1)\n",
     "        else:\n",
     "            seq_length = MAX_LENGTH\n",
     "        \n",
     "        # Perform seq_length decoding steps.\n",
     "        for i in range(seq_length):\n",
     "            # Each step consumes one token and one hidden state.\n",
     "            decoder_output, decoder_hidden = self.forward_step(\\\n",
     "                decoder_input, decoder_hidden)\n",
     "            decoder_outputs.append(decoder_output)\n",
     "\n",
     "            if target_tensor is not None:\n",
     "                # Teacher forcing: use the ground-truth target token as the\n",
     "                # next input.\n",
     "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
     "            else:\n",
     "                # Pick the highest-probability prediction from this step's\n",
     "                # output distribution as the next input.\n",
     "                _, topi = decoder_output.topk(1)\n",
     "                # detach() cuts the token off the computation graph so no\n",
     "                # gradient flows back through the sampled input.\n",
     "                decoder_input = topi.squeeze(-1).detach()\n",
     "\n",
     "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
     "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
     "        # Return None as the third value to keep the interface consistent\n",
     "        # with AttnRNNDecoder (which returns attention weights there).\n",
     "        return decoder_outputs, decoder_hidden, None\n",
     "\n",
     "    # Decode a single step.\n",
     "    def forward_step(self, input, hidden):\n",
     "        output = self.embedding(input)\n",
     "        output = F.relu(output)\n",
     "        output, hidden = self.gru(output, hidden)\n",
     "        output = self.out(output)\n",
     "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38306ed9",
   "metadata": {},
   "source": [
    "下面介绍基于注意力机制的循环神经网络解码器的代码实现。\n",
    "我们使用一个注意力层来计算注意力权重，其输入为解码器的输入和隐状态。\n",
    "这里使用Bahdanau注意力（Bahdanau attention），这是序列到序列模型中应用最广泛的注意力机制，特别是机器翻译任务。该注意力机制使用一个对齐模型（alignment model）来计算编码器和解码器隐状态之间的注意力分数，具体来讲就是一个前馈神经网络。相比于点乘注意力，Bahdanau注意力利用了非线性变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d675aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
     "\"\"\"\n",
     "Adapted from the GitHub project pytorch/tutorials\n",
     "(Copyright (c) 2023, PyTorch, BSD-3-Clause License; see appendix).\n",
     "\"\"\"\n",
     "import torch.nn.functional as F\n",
     "\n",
     "class BahdanauAttention(nn.Module):\n",
     "    \"\"\"Additive (Bahdanau) attention: scores are computed by a small\n",
     "    feed-forward network over the query and the keys.\"\"\"\n",
     "    def __init__(self, hidden_size):\n",
     "        super(BahdanauAttention, self).__init__()\n",
     "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
     "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
     "        self.Va = nn.Linear(hidden_size, 1)\n",
     "\n",
     "    def forward(self, query, keys):\n",
     "        # query: batch * 1 * hidden_size\n",
     "        # keys: batch * seq_length * hidden_size\n",
     "        # Broadcasting aligns the two shapes in the sum below.\n",
     "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
     "        # batch * seq_length * 1 -> batch * 1 * seq_length\n",
     "        scores = scores.squeeze(2).unsqueeze(1)\n",
     "\n",
     "        weights = F.softmax(scores, dim=-1)\n",
     "        # batch * 1 * hidden_size: a weighted sum of the keys.\n",
     "        context = torch.bmm(weights, keys)\n",
     "        return context, weights\n",
    "\n",
     "class AttnRNNDecoder(nn.Module):\n",
     "    \"\"\"GRU decoder with Bahdanau attention over the encoder outputs.\"\"\"\n",
     "    def __init__(self, vocab_size, hidden_size):\n",
     "        super(AttnRNNDecoder, self).__init__()\n",
     "        self.hidden_size = hidden_size\n",
     "        self.vocab_size = vocab_size\n",
     "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
     "        self.attention = BahdanauAttention(hidden_size)\n",
     "        # The GRU consumes the embedded decoder input concatenated with\n",
     "        # the context vector, hence input size 2 * hidden_size.\n",
     "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\\\n",
     "            batch_first=True)\n",
     "        # Maps the attention result to a distribution over the vocabulary.\n",
     "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
     "\n",
     "    # Decode an entire sequence.\n",
     "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
     "        batch_size = encoder_outputs.size(0)\n",
     "        # Start decoding from <sos>.\n",
     "        decoder_input = torch.empty(batch_size, 1, dtype=\\\n",
     "            torch.long).fill_(SOS_token)\n",
     "        decoder_hidden = encoder_hidden\n",
     "        decoder_outputs = []\n",
     "        attentions = []\n",
     "\n",
     "        # If the target sequence is known, the number of decoding steps is\n",
     "        # fixed by its length; otherwise decode up to MAX_LENGTH steps.\n",
     "        if target_tensor is not None:\n",
     "            seq_length = target_tensor.size(1)\n",
     "        else:\n",
     "            seq_length = MAX_LENGTH\n",
     "        \n",
     "        # Perform seq_length decoding steps.\n",
     "        for i in range(seq_length):\n",
     "            # Each step consumes one token and one hidden state.\n",
     "            decoder_output, decoder_hidden, attn_weights = \\\n",
     "                self.forward_step(\n",
     "                    decoder_input, decoder_hidden, encoder_outputs\n",
     "            )\n",
     "            decoder_outputs.append(decoder_output)\n",
     "            attentions.append(attn_weights)\n",
     "\n",
     "            if target_tensor is not None:\n",
     "                # Teacher forcing: use the ground-truth target token as the\n",
     "                # next input.\n",
     "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
     "            else:\n",
     "                # Pick the highest-probability prediction from this step's\n",
     "                # output distribution as the next input.\n",
     "                _, topi = decoder_output.topk(1)\n",
     "                # detach() cuts the token off the computation graph so no\n",
     "                # gradient flows back through the sampled input.\n",
     "                decoder_input = topi.squeeze(-1).detach()\n",
     "\n",
     "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
     "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
     "        attentions = torch.cat(attentions, dim=1)\n",
     "        # Keep the interface consistent with RNNDecoder; here the third\n",
     "        # return value carries the attention weights.\n",
     "        return decoder_outputs, decoder_hidden, attentions\n",
     "\n",
     "    # Decode a single step.\n",
     "    def forward_step(self, input, hidden, encoder_outputs):\n",
     "        embeded =  self.embedding(input)\n",
     "        # The GRU hidden state is 1 * batch * hidden_size, while the\n",
     "        # attention module expects its query as batch * 1 * hidden_size.\n",
     "        query = hidden.permute(1, 0, 2)\n",
     "        context, attn_weights = self.attention(query, encoder_outputs)\n",
     "        input_gru = torch.cat((embeded, context), dim=2)\n",
     "        # The GRU expects the input hidden state as 1 * batch * hidden_size.\n",
     "        output, hidden = self.gru(input_gru, hidden)\n",
     "        output = self.out(output)\n",
     "        return output, hidden, attn_weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63b031",
   "metadata": {},
   "source": [
    "\n",
    "接下来我们实现基于Transformer的编码器和解码器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a5c4c46-c230-4b0d-a647-06461ac4a2ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ERROR: Could not find a version that satisfies the requirement transformer (from versions: none)\n",
      "ERROR: No matching distribution found for transformer\n"
     ]
    }
   ],
   "source": [
    "pip install transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c9a91cf4",
   "metadata": {},
   "outputs": [],
   "source": [
     "import torch\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
     "import numpy as np\n",
     "import sys\n",
     "sys.path.append('./code')\n",
     "# EmbeddingLayer, TransformerLayer, MultiHeadSelfAttention, AddNorm and\n",
     "# PositionWiseFNN come from the local module ./code/transformer.py.\n",
     "from transformer import *\n",
     "\n",
     "class TransformerEncoder(nn.Module):\n",
     "    \"\"\"Single-layer Transformer encoder over one input sentence.\"\"\"\n",
     "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
     "            dropout, intermediate_size):\n",
     "        super().__init__()\n",
     "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
     "            hidden_size)\n",
     "        # Use TransformerLayer directly as the encoding layer; for\n",
     "        # simplicity only a single layer is used.\n",
     "        self.layer = TransformerLayer(hidden_size, num_heads,\\\n",
     "            dropout, intermediate_size)\n",
     "        # Unlike TransformerLM, the encoder needs no output linear layer.\n",
     "        \n",
     "    def forward(self, input_ids):\n",
     "        # This forward() handles only one sentence at a time; supporting\n",
     "        # batches would require returning hidden states according to the\n",
     "        # lengths of the input sequences.\n",
     "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
     "        seq_len = input_ids.size(1)\n",
     "        assert seq_len <= self.embedding_layer.max_len\n",
     "        \n",
     "        # 1 * seq_len\n",
     "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
     "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
     "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
     "        hidden_states = self.layer(input_states, attention_mask)\n",
     "        return hidden_states, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385b2c7e-b1a2-42b1-ad1c-d1f4a2f682fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5e8cfc9c",
   "metadata": {},
   "outputs": [],
   "source": [
     "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
     "    \"\"\"Multi-head cross-attention: queries come from the target side,\n",
     "    keys and values from the source side.\"\"\"\n",
     "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
     "        \"\"\"\n",
     "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
     "        tgt_mask: batch_size * tgt_seq_len\n",
     "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
     "        src_mask: batch_size * src_seq_len\n",
     "        \"\"\"\n",
     "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
     "        queries = self.transpose_qkv(self.W_q(tgt))\n",
     "        keys = self.transpose_qkv(self.W_k(src))\n",
     "        values = self.transpose_qkv(self.W_v(src))\n",
     "        # Unlike self-attention, compute a cross mask between target and\n",
     "        # source positions.\n",
     "        # batch_size * tgt_seq_len * src_seq_len\n",
     "        attention_mask = tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)\n",
     "        # Repeat the mask elements so every attention head gets a copy.\n",
     "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
     "        attention_mask = torch.repeat_interleave(attention_mask,\\\n",
     "            repeats=self.num_heads, dim=0)\n",
     "        # (batch_size * num_heads) * tgt_seq_len * \\\n",
     "        # (hidden_size / num_heads)\n",
     "        output = self.attention(queries, keys, values, attention_mask)\n",
     "        # batch * tgt_seq_len * hidden_size\n",
     "        output_concat = self.transpose_output(output)\n",
     "        return self.W_o(output_concat)\n",
     "\n",
     "# TransformerDecoderLayer adds cross multi-head attention on top of what\n",
     "# TransformerLayer provides.\n",
     "class TransformerDecoderLayer(nn.Module):\n",
     "    def __init__(self, hidden_size, num_heads, dropout,\\\n",
     "                 intermediate_size):\n",
     "        super().__init__()\n",
     "        self.self_attention = MultiHeadSelfAttention(hidden_size,\\\n",
     "            num_heads, dropout)\n",
     "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
     "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\\\n",
     "            num_heads, dropout)\n",
     "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
     "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
     "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
     "\n",
     "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
     "        # Masked multi-head self-attention over the target states.\n",
     "        tgt = self.add_norm1(tgt_states, self.self_attention(\\\n",
     "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
     "        # Cross multi-head attention over the source states.\n",
     "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\\\n",
     "            tgt_mask, src_states, src_mask))\n",
     "        # Position-wise feed-forward network.\n",
     "        return self.add_norm3(tgt, self.fnn(tgt))\n",
     "\n",
     "class TransformerDecoder(nn.Module):\n",
     "    \"\"\"Single-layer Transformer decoder that decodes one sentence step\n",
     "    by step (teacher forcing or greedy).\"\"\"\n",
     "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
     "                 dropout, intermediate_size):\n",
     "        super().__init__()\n",
     "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
     "            hidden_size)\n",
     "        # For simplicity only a single layer is used.\n",
     "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\\\n",
     "            dropout, intermediate_size)\n",
     "        # Like TransformerLM, the decoder needs an output layer.\n",
     "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
     "        \n",
     "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
     "        # Only one sentence at a time: 1 * seq_len * hidden_size.\n",
     "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
     "        \n",
     "        if tgt_tensor is not None:\n",
     "            # Only one sentence at a time: 1 * seq_len.\n",
     "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
     "            seq_len = tgt_tensor.size(1)\n",
     "            assert seq_len <= self.embedding_layer.max_len\n",
     "        else:\n",
     "            seq_len = self.embedding_layer.max_len\n",
     "        \n",
     "        decoder_input = torch.empty(1, 1, dtype=torch.long).\\\n",
     "            fill_(SOS_token)\n",
     "        decoder_outputs = []\n",
     "        \n",
     "        for i in range(seq_len):\n",
     "            decoder_output = self.forward_step(decoder_input,\\\n",
     "                src_mask, src_states)\n",
     "            decoder_outputs.append(decoder_output)\n",
     "            \n",
     "            if tgt_tensor is not None:\n",
     "                # Teacher forcing: append the ground-truth target token to\n",
     "                # the growing input prefix.\n",
     "                decoder_input = torch.cat((decoder_input,\\\n",
     "                    tgt_tensor[:, i:i+1]), 1)\n",
     "            else:\n",
     "                # Append the highest-probability prediction from this\n",
     "                # step's output distribution to the input prefix.\n",
     "                _, topi = decoder_output.topk(1)\n",
     "                # detach() cuts the token off the computation graph so no\n",
     "                # gradient flows back through the sampled input.\n",
     "                decoder_input = torch.cat((decoder_input,\\\n",
     "                    topi.squeeze(-1).detach()), 1)\n",
     "                \n",
     "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
     "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
     "        # Keep the interface consistent with RNNDecoder (no hidden state\n",
     "        # or attention weights to return).\n",
     "        return decoder_outputs, None, None\n",
     "        \n",
     "    # Decode one step. The interface differs slightly from RNNDecoder:\n",
     "    # RNNDecoder consumes one hidden state and one token per step and\n",
     "    # returns a distribution plus a hidden state; TransformerDecoder takes\n",
     "    # no hidden state, consumes the whole target-side history so far, and\n",
     "    # returns only a distribution.\n",
     "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
     "        seq_len = tgt_inputs.size(1)\n",
     "        # 1 * seq_len\n",
     "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
     "        # NOTE(review): the target mask is all ones (no causal masking);\n",
     "        # this appears safe here because decoding is incremental and only\n",
     "        # the last position's output is used -- confirm before batching.\n",
     "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
     "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
     "        hidden_states = self.layer(src_states, src_mask, tgt_states,\\\n",
     "            tgt_mask)\n",
     "        # Keep only the distribution for the last position.\n",
     "        output = self.output_layer(hidden_states[:, -1:, :])\n",
     "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99b58199-3be0-4fa2-bd59-dc12476e6713",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "0e58b811",
   "metadata": {},
   "source": [
    "下面以机器翻译（中-英）为例展示如何训练序列到序列模型。这里使用的是中英文Books数据，其中中文标题来源于第4章所使用的数据集，英文标题是使用已训练好的机器翻译模型从中文标题翻译而得，因此该数据并不保证准确性，仅用于演示。\n",
    "\n",
    "首先需要对源语言和目标语言分别建立索引，并记录词频。\n",
    "\n",
    "<!-- 代码来自：\n",
    "\n",
    "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <span style=\"color:blue;font-size:20px\">BSD-3-Clause license</span>\n",
    "\n",
    "下载链接：\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.en\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.zh\n",
    " -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6b258fc5",
   "metadata": {},
   "outputs": [],
   "source": [
     "\"\"\"\n",
     "Adapted from the GitHub project pytorch/tutorials\n",
     "(Copyright (c) 2023, PyTorch, BSD-3-Clause License; see appendix).\n",
     "\"\"\"\n",
     "SOS_token = 0\n",
     "EOS_token = 1\n",
     "\n",
     "class Lang:\n",
     "    \"\"\"Vocabulary for one language: word<->index maps plus word counts.\"\"\"\n",
     "    def __init__(self, name):\n",
     "        self.name = name\n",
     "        # word -> index\n",
     "        self.word2index = {}\n",
     "        # word -> occurrence count\n",
     "        self.word2count = {}\n",
     "        self.index2word = {0: \"<sos>\", 1: \"<eos>\"}\n",
     "        self.n_words = 2  # Count SOS and EOS\n",
     "\n",
     "    def addSentence(self, sentence):\n",
     "        # Sentences are assumed to be pre-tokenized with single spaces.\n",
     "        for word in sentence.split(' '):\n",
     "            self.addWord(word)\n",
     "\n",
     "    def addWord(self, word):\n",
     "        if word not in self.word2index:\n",
     "            self.word2index[word] = self.n_words\n",
     "            self.word2count[word] = 1\n",
     "            self.index2word[self.n_words] = word\n",
     "            self.n_words += 1\n",
     "        else:\n",
     "            self.word2count[word] += 1\n",
     "            \n",
     "    def sent2ids(self, sent):\n",
     "        # Convert a space-separated sentence to a list of word indices.\n",
     "        return [self.word2index[word] for word in sent.split(' ')]\n",
     "    \n",
     "    def ids2sent(self, ids):\n",
     "        # Convert a list of word indices back to a space-joined sentence.\n",
     "        return ' '.join([self.index2word[idx] for idx in ids])\n",
     "\n",
     "import unicodedata\n",
     "import string\n",
     "import re\n",
     "import random\n",
     "\n",
     "# The files are Unicode-encoded; convert to ASCII (dropping combining\n",
     "# marks), lowercase the text and normalize punctuation.\n",
     "def unicodeToAscii(s):\n",
     "    return ''.join(\n",
     "        c for c in unicodedata.normalize('NFD', s)\n",
     "        if unicodedata.category(c) != 'Mn'\n",
     "    )\n",
     "\n",
     "def normalizeString(s):\n",
     "    s = unicodeToAscii(s.lower().strip())\n",
     "    # Insert a space before punctuation marks.\n",
     "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
     "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "01ae1e7f",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Read the two parallel files; the same line of each file holds one\n",
     "# source/target sentence pair.\n",
     "def readLangs(lang1, lang2):\n",
     "    # Read each file and split it into sentences (one per line).\n",
     "    lines1 = open(f'{lang1}.txt', encoding='utf-8').read()\\\n",
     "        .strip().split('\\n')\n",
     "    lines2 = open(f'{lang2}.txt', encoding='utf-8').read()\\\n",
     "        .strip().split('\\n')\n",
     "    print(len(lines1), len(lines2))\n",
     "    \n",
     "    # Normalize the text.\n",
     "    lines1 = [normalizeString(s) for s in lines1]\n",
     "    lines2 = [normalizeString(s) for s in lines2]\n",
     "    # Chinese text is tokenized into individual characters.\n",
     "    if lang1 == 'zh':\n",
     "        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]\n",
     "    if lang2 == 'zh':\n",
     "        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]\n",
     "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
     "\n",
     "    input_lang = Lang(lang1)\n",
     "    output_lang = Lang(lang2)\n",
     "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d6a561e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['大 师 谈 游 戏 设 计 创 意 与 节 奏', 'masters talk about game design , creativity and rhythm .']\n"
     ]
    }
   ],
   "source": [
     "# To speed up training, filter out overly long sentences.\n",
     "MAX_LENGTH = 30\n",
     "\n",
     "def filterPair(p):\n",
     "    # Keep a pair only if both sides are shorter than MAX_LENGTH tokens.\n",
     "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
     "        len(p[1].split(' ')) < MAX_LENGTH\n",
     "\n",
     "def filterPairs(pairs):\n",
     "    return [pair for pair in pairs if filterPair(pair)]\n",
     "\n",
     "# Read the parallel corpus, filter long pairs, then build both\n",
     "# vocabularies from the remaining pairs.\n",
     "def prepareData(lang1, lang2):\n",
     "    input_lang, output_lang, pairs = readLangs(lang1, lang2)\n",
     "    print(f\"读取 {len(pairs)} 对序列\")\n",
     "    pairs = filterPairs(pairs)\n",
     "    print(f\"过滤后剩余 {len(pairs)} 对序列\")\n",
     "    print(\"统计词数\")\n",
     "    for pair in pairs:\n",
     "        input_lang.addSentence(pair[0])\n",
     "        output_lang.addSentence(pair[1])\n",
     "    print(input_lang.name, input_lang.n_words)\n",
     "    print(output_lang.name, output_lang.n_words)\n",
     "    return input_lang, output_lang, pairs\n",
     "\n",
     "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
     "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfa6874b",
   "metadata": {},
   "source": [
    "为了便于训练，对每一对源-目标句子需要准备一个源张量（源句子的词元索引）和一个目标张量（目标句子的词元索引）。在两个句子的末尾会添加“\\<eos\\>”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "92baa7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
     "def get_train_data():\n",
     "    # Build (source ids, target ids) training pairs from the prepared\n",
     "    # corpus, appending <eos> to both sides of every pair.\n",
     "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
     "    train_data = []\n",
     "    for idx, (src_sent, tgt_sent) in enumerate(pairs):\n",
     "        src_ids = input_lang.sent2ids(src_sent)\n",
     "        tgt_ids = output_lang.sent2ids(tgt_sent)\n",
     "        # Append <eos> to both sequences.\n",
     "        src_ids.append(EOS_token)\n",
     "        tgt_ids.append(EOS_token)\n",
     "        train_data.append([src_ids, tgt_ids])\n",
     "    return input_lang, output_lang, train_data\n",
     "        \n",
     "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411c9ed3",
   "metadata": {},
   "source": [
    "<!--训练时，我们使用编码器编码源句子，并且保留每一步的输出向量和最终的隐状态。然后解码器以“\\<sos\\>”作为第一个输入，以及编码器的最终隐状态作为它的第一个隐状态。-->\n",
    "接下来是训练代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8b85d1ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, loss=0.2426: 100%|█| 20/20 [12:59<00:00, 38.97s/it\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAABMl0lEQVR4nO3deXgT1f4G8Dfd0n1foNBSoKyFsoNlE2UHEb3en4p4BdeLgogICioKerW4g4joVS+4ISoiqAVkB9mhUCh7Cy0t0AW6LzRtk/P7o3Zo2qRJ2qSTpO/nefI8yeTMyXdIa1/PnDmjEEIIEBEREVkhB7kLICIiItKHQYWIiIisFoMKERERWS0GFSIiIrJaDCpERERktRhUiIiIyGoxqBAREZHVcpK7gMbQaDS4du0avLy8oFAo5C6HiIiIjCCEQFFREUJDQ+HgUP+YiU0HlWvXriEsLEzuMoiIiKgB0tPT0bp163rb2HRQ8fLyAlB1oN7e3jJXQ0RERMYoLCxEWFiY9He8PjYdVKpP93h7ezOoEBER2Rhjpm1wMi0RERFZLQYVIiIisloMKkRERGS1GFSIiIjIajGoEBERkdViUCEiIiKrxaBCREREVotBhYiIiKwWgwoRERFZLQYVIiIisloMKkRERGS1GFSIiIjIajGo6FFWoYYQQu4yiIiImjUGFR0uXS9G5wWb8eLak3KXQkRE1KwxqOjw5d4UAMDP8VdkroSIiKh5Y1DRwVPpJHcJREREBAYVndxdHOUugYiIiCBzUFGr1ViwYAHatm0LNzc3tG/fHm+++absk1hrjqjIXQsREVFzJus5jnfeeQcrVqzA119/jaioKBw9ehSPPvoofHx8MHPmTNnq6hfhLz0vLVfDg6eCiIiIZCHrX+D9+/dj4sSJGD9+PAAgIiICP/zwAw4fPixnWejeykd6rqrUwEMpYzFERETNmKynfgYOHIjt27fjwoULAIATJ05g7969GDt2rM72KpUKhYWFWg9LcHBQwEFR9bxSrbHIZxAREZFhso6ozJs3D4WFhejcuTMcHR2hVqvx1ltvYfLkyTrbx8bGYtGiRU1Sm5ODA8rVGlRqOEeFiIhILrKOqPz000/4/vvvsXr1ahw7dgxff/013n//fXz99dc628+fPx8FBQXSIz093WK1OTlWDamoGVSIiIhkI+uIyty5czFv3jw8+OCDAIDu3bvj8uXLiI2NxZQpU+q0VyqVUCqbZsKI49/nfjiiQkREJB9ZR1RKS0vh4KBdgqOjIzQa+eeFVAcVtRXUQkRE1FzJOqIyYcIEvPXWWwgPD0dUVBSOHz+ODz/8EI899picZQEA/p5LCy6jQkREJB9Zg8qyZcuwYMECPPPMM8jOzkZoaCj+/e9/47XXXpOzLACAg6IqqjCnEBERyUfWoOLl5YUlS5ZgyZIlcpah0985BRoOqRAREcmG9/rRQ1E9osKcQkREJBsGFT2q56hwRIWIiEg+DCp6OHBEhYiISHYMKnpUz1FhUCEiIpIPg4oet676YVIhIiKSC4OKAVyYloiISD4MKnrcOvXDpEJERCQXBhU9qk/9cESFiIhIPgwqelSPqHBtWiIiIvkwqOjBERUiIiL5MajowZsSEhERyY9BRQ/e64eIiEh+DCp68F4/RERE8mNQ0cOBlycTERHJjkFFDwWqV6YlIiIiuTCo6ME5KkRERPJjUNGDc1SIiIjkx6CihwNHVIiIiGTHoKKHdK8fecsgIiJq1hhU9JBWpuXStERERLJhUNGDK9MSERHJj0FFD2kyrcx1EBERNWcMKnpwMi0REZH8GFT04OXJRERE8mNQ0YNL6BMREcmPQUWP6hEVXvRDREQkHwYVPaSrfjidloiISDYMKno4cESFiIhIdgwqeig4R4WIiEh2DCp6OPCqHyIiItnJGlQiIiKgUCjqPKZPny5nWQBujahwHRUiIiL5OMn54UeOHIFarZZenzp1CiNHjsT/
/d//yVhVFa6jQkREJD9Zg0pQUJDW68WLF6N9+/a4/fbbZaroFq5MS0REJD9Zg0pN5eXl+O677zB79mxpNKM2lUoFlUolvS4sLLRYPZyjQkREJD+rmUy7fv165OfnY+rUqXrbxMbGwsfHR3qEhYVZrB6uo0JERCQ/qwkqX331FcaOHYvQ0FC9bebPn4+CggLpkZ6ebrF6uDItERGR/Kzi1M/ly5exbds2rFu3rt52SqUSSqWySWriHBUiIiL5WcWIysqVKxEcHIzx48fLXYrk1oJv8tZBRETUnMkeVDQaDVauXIkpU6bAyckqBngA1JxMy6RCREQkF9mDyrZt25CWlobHHntM7lK0SEFF5jqIiIiaM9mHMEaNGmWdoxbVc1Q4m5aIiEg2so+oWCvePZmIiEh+DCp6VF/1w5xCREQkHwYVPaQF36zxtBQREVEzwaCix61TPwwqREREcmFQ0YN3TyYiIpIfg4oeCmllWnnrICIias4YVPS4UVx1l+a9yddlroSIiKj5YlDRY9f5qoCyLzlH5kqIiIiaLwYVIiIisloMKkRERGS1GFSIiIjIajGoEBERkdViUCEiIiKrxaBCREREVotBhYiIiKwWg4oeq58YIHcJREREzR6Dih6t/Nyk59Wr1BIREVHTYlDRo/ruyQCQVVgGIQT2Jd9AdlGZjFURERE1L05yF2CtHB1uBZWnvonH4MhA/Hg0HQCQuni8XGURERE1KxxR0aNmULmaf1MKKURERNR0GFT0qHHmh4iIiGTCoKKHI5MKERGR7BhU9HBx0v9PcyWvtAkrISIiar4YVPTwcnXW+97gd3Y2YSVERETNF4MKERERWS0GlQYqKquQuwQiIiK7x6DSQF/suQQAOHW1ANvOZMlcDRERkX1iUKnHqK4het8rUlUCAO5athdPfHMUx9LymqosIiKiZoNBpR7Ojvr/eRxqXb68dFuSpcshIiJqdhhU6hGXmGF02z5t/CxYCRERUfPEoNJAX+1NQVmFWnrdusbdlomIiMg8ZA8qV69excMPP4yAgAC4ubmhe/fuOHr0qNxlGaV6Qi0ACCFjIURERHZK1rsn5+XlYdCgQbjjjjuwadMmBAUFISkpCX5+tnEaJSm7WHrOnEJERGR+sgaVd955B2FhYVi5cqW0rW3btnrbq1QqqFQq6XVhYaFF6zPktxPXpOeCQypERERmJ+upn99++w19+/bF//3f/yE4OBi9evXCF198obd9bGwsfHx8pEdYWFgTVls/5hQiIiLzkzWoXLp0CStWrECHDh3w559/4umnn8bMmTPx9ddf62w/f/58FBQUSI/09PQmrli/F385icyCMrnLICIisisKIeM5CxcXF/Tt2xf79++Xts2cORNHjhzBgQMHDO5fWFgIHx8fFBQUwNvb2+z1jVmyB+cyi4xvH9UCn/2rj9nrICIisiem/P2WdUSlZcuW6Nq1q9a2Ll26IC0tTaaKtBkKKT5u2ndY3nw6k8vpExERmZGsQWXQoEE4f/681rYLFy6gTZs2MlVkGicHRZ1tT3xjG5dWExER2QJZg8rzzz+PgwcP4u2330ZycjJWr16N//73v5g+fbqcZUlCvJX1vp9TUt5ElRARETVPsgaVfv364ddff8UPP/yAbt264c0338SSJUswefJkOcuSLLirq+FGREREZDGyTqZtLEtPphVC4PS1QizbkYQ/Txs39yTc3x17XrzD7LUQERHZC5uZTGvtFAoFurXywWcP90HCayMR7u9ucB+1xmZzHxERkdVhUDGCQqGAr7sL8oyYk3I1/2YTVERERNQ8MKiYwNmJ/1xERERNiX95TeDsWPdyZCIiIrIcBhUTODnwn4uIiKgp8S+vCRx1LPBGRERElsOgYgJPpZPcJRARETUrDComcOFkWiIioibFv7wmaOXrZlQ7IQRiN53Fz0fTLVwRERGRfWNQMcHgDoFGtTtwKQef776EuWtPWrgiIiIi+8agYoLe4X5Gtdtw/JqF
KyEiImoeGFRM0KmFl1HtfuQpHyIiIrNgULGwiHlxKCitkLsMIiIim8Sg0gR6vLEF649flbsMIiIim8Og0kRm/ZggdwlEREQ2h0GlCWk0AptPZWDHuSy5SyEiIrIJDComGhPVosH7PrvmOKZ9dwyPrTqKpKwiM1ZFRERknxhUTNSY+xLGncyQnl+8XmKGaoiIiOwbg4pM1BohdwlERERWj0FFJmrBoEJERGQIg4qJjF2d1hC1RmOWfoiIiOwZg4qJpgyMwJv3dMO22bc3qp9d56+bqSIiIiL7xaBiImdHB/zrtjaIDPZsVD8bEng/ICIiIkMYVIiIiMhqMajIZHBkoNwlEBERWT0GFZnsTb4hdwlERERWj0FFJoGeLnKXQEREZPUYVGTy/MiOcpdARERk9RhUiIiIyGoxqMjEQaGQuwQiIiKrJ2tQWbhwIRQKhdajc+fOcpbUZNycHeUugYiIyOo5yV1AVFQUtm3bJr12cpK9pAZ5+LZwfHcwzej2Gt7rh4iIyCDZU4GTkxNatGhhVFuVSgWVSiW9LiwstFRZJvvPPd1x7HI+zmQYVxNvnkxERGSY7HNUkpKSEBoainbt2mHy5MlIS9M/KhEbGwsfHx/pERYW1oSVGrbumYFGt+WIChERkWGyBpUBAwZg1apV2Lx5M1asWIGUlBQMGTIERUVFOtvPnz8fBQUF0iM9Pb2JK66fqwnzTgSDChERkUGynvoZO3as9Dw6OhoDBgxAmzZt8NNPP+Hxxx+v016pVEKpVDZliRbDUz9ERESGyX7qpyZfX1907NgRycnJcpdicTz1Q0REZJhVBZXi4mJcvHgRLVu2lLsUozTmxoIcUSEiIjJM1qAyZ84c7N69G6mpqdi/fz/uvfdeODo6YtKkSXKWZTRnx4Yv2sY5KkRERIbJOkflypUrmDRpEnJychAUFITBgwfj4MGDCAoKkrMsozk5NjznaTikQkREZJCsQWXNmjVyfnyjNWZEhTmFiIjIMKuao2JrnBwaMaLCUz9EREQGMag0Qv+2/nW2xf6juwyVEBER2SfZl9C3ZZP6h8PF0QF9I/ykbe4uxi36ZsyISoVaA+dGzIMhIiKydfwr2AiODgrc3y8M7YI8pW3dWvnUafdA37pL/Ruao3Lvp/vQ4ZVNOHQpp9F1EhER2SoGFTNrH+SJ32YM0tr2zj+jMXlAuNY2QyMqx9PyAQAP/PegWesjIiKyJQwqFhDd2rfOttr3AeJcWiIiIsMYVJqIq7P2P7Wa1ycTEREZxKBiYeOjq24H4OqkPaKi69TP5lOZeHTlYeQUq5qkNiIiImvHq34s7La/L2GufepH14jKtO/iAQCxm85ZvjAiIiIbwKBiIWunxWBv8g1M6l81ibb2qZ/Kek79rI2/YtHaiIiIbAWDioX0jfBH34hbC8IpjRhRISIiIm2co9JEap/6qVRrB5XEKwVNWQ4REZFNYFBpIkqn2lf9aCCEwPazWbiSV4pX1ifWu78QAhevF/Ouy0RE1KwwqDQRt9ojKhqBrWey8PjXRzH4nZ317iuEwJd/pWD4B7vxyvpTSM4uQuzGs8gtKbdkyURERLLjHJUmUntEZce5bK37+Cjq2Xf/xRy8tfEsAOCHw2n45dgVlFdqcPF6Cb6c0hfXi1RYfSgN9/drjZY+bpYon4iISBYcUWki3Vtr3wMoo6DM6Am1k788pPW6vFIDANh2NgunrhZg4id78dG2Cxi79C/zFEtERGQlGFSaiLuLE86+MUZrW2FZRaP7vWvZXlwrKAMA5JdWIPPv50RERPaAQaUJubloz1PZkHBNep5ZaJ6A8cLPCWbph4iIyBowqFiJrELzLJt/IavYLP0QERFZAwaVJuZQ36xZIiIi0sKg0sQGtA2waP867nVIRERksxoUVL7++mvExcVJr1988UX4+vpi4MCBuHz5stmKs0ezRnSwaP83eOdlIiKyIw0KKm+//Tbc3KrW6zhw4ACWL1+Od999F4GBgXj++efN
WqC98VBy6RoiIiJjNeivZnp6OiIjIwEA69evx3333YennnoKgwYNwrBhw8xZn91xUHCSChERkbEaNKLi6emJnJwcAMCWLVswcuRIAICrqytu3rxpvurskANnBRERERmtQSMqI0eOxBNPPIFevXrhwoULGDduHADg9OnTiIiIMGd9die32LL356l9TyEiIiJb1qD/v1++fDliYmJw/fp1/PLLLwgIqLqSJT4+HpMmTTJrgfbGHKvR1odnloiIyJ40aETF19cXn3zySZ3tixYtanRB9q57a1+L9s85MEREZE8aNKKyefNm7N27V3q9fPly9OzZEw899BDy8vLMVpw9crbwim+MKUREZE8aFFTmzp2LwsJCAEBiYiJeeOEFjBs3DikpKZg9e7ZZCyQTMakQEZEdaVBQSUlJQdeuXQEAv/zyC+666y68/fbbWL58OTZt2tSgQhYvXgyFQoFZs2Y1aH9bEeSltGj/zClERGRPGhRUXFxcUFpaCgDYtm0bRo0aBQDw9/eXRlpMceTIEXz++eeIjo5uSDk2RcE5JEREREZrUFAZPHgwZs+ejTfffBOHDx/G+PHjAQAXLlxA69atTeqruLgYkydPxhdffAE/P7+GlEM1MAgREZE9aVBQ+eSTT+Dk5IS1a9dixYoVaNWqFQBg06ZNGDNmjEl9TZ8+HePHj8eIESMMtlWpVCgsLNR6kDbmFCIisicNujw5PDwcf/zxR53tH330kUn9rFmzBseOHcORI0eMah8bG8tLoImIiJqRBt8hT61WY/369Th79iwAICoqCnfffTccHY1bGTU9PR3PPfcctm7dCldXV6P2mT9/vtZVRYWFhQgLCzO9eDvGARUiIrInDQoqycnJGDduHK5evYpOnToBqBrtCAsLQ1xcHNq3b2+wj/j4eGRnZ6N3797SNrVajT179uCTTz6BSqWqE3qUSiWUSsteNWPr8kotu/ItERFRU2pQUJk5cybat2+PgwcPwt/fHwCQk5ODhx9+GDNnzkRcXJzBPoYPH47ExEStbY8++ig6d+6Ml156yeiRGSIiIrJfDQoqu3fv1gopABAQEIDFixdj0KBBRvXh5eWFbt26aW3z8PBAQEBAne3NyTeP9ccj/zssdxlERERWoUFX/SiVShQVFdXZXlxcDBcXl0YX1ZwN7RgkdwlERERWo0EjKnfddReeeuopfPXVV+jfvz8A4NChQ5g2bRruvvvuBheza9euBu9LRERE9qdBIyoff/wx2rdvj5iYGLi6usLV1RUDBw5EZGQklixZYuYSyRQdgj3lLoGIiMhsGjSi4uvriw0bNiA5OVm6PLlLly6IjIw0a3FkOkcL352ZiIioKRkdVAzdFXnnzp3S8w8//LDhFZGWhRO6Ysn2JOQbednxucy6c4eIiIhsldFB5fjx40a1471mDBsUGYB9yTlGtZ06qC0eiYlAu5c36nx/YPsAtPZzw09Hr5izRCIiIqtgdFCpOWJCjfPq+K4Yu/QvAMBDA8Kx+lBave0d6jmds/rJ27DmcBqDChER2aUGTaYl83l1fJcG77tsUi8A9QcZIiIiW8agIjOHRpwq6xjiBQBw5Ok2IiKyUwwqNsyR3x4REdk5/qmTgVojpOdOJp62eWhAOACgd7gv2gdVrZnSqYWX+YojIiKyIg1aR4Uap2ZQMXXdk9Z+bkhdPF5rW7dWPugQ7Imk7GKz1EdERGQtOKIiM4VCgb5t/AAAd3YONti+Ui10bp9+h/kX2xvw9jZEzIuDELo/k4iIyNI4oiKDbq180DvcF6393AEA/32kLzYmZmBCj1Ctdm0C3KXnwV5KZBep9IaZ6tM/rs5V2VNVqYbSyVGrzae7krE36Qb+N7UfXJ0d6/RRU06xClmFKgDAgYs5GBgZaMIREhERmQeDigwcHRRY98wg6bW/hwsevq1NnXY1TwrtmjsM14tUaBPgobdPAPBwccLC305j1f5UbJw5BF1DvQEA7/95Hp/sTAYA/Hr8Kib1
r5rrUqHWIK+0HMFerlr9VdY4PXWjpNz0gyQiIjIDnvqxQr7uzgCAwR1ujWK4uzjpDSkAUD3VRSMEVu1PBQB8vD0JAFCp1kghBagKJ9Xu/mQf+r+1HecyC7X6qzl3pqJSAyIiIjlwRMUK/fHsYGw5nYUH+4cZvU/1rQtqDIRIYWP3hetabWuGkLMZVQHlro/3Ynx0S3x0f084OCi0RnMCPF1MPAIiIiLz4IiKFWrt547HBreFu4vxObJ64ThNjaSiUACpN0pQWq7WaqtrgbhKjcCGhGvYk1QVamoGniAvpSnlExERmQ1HVOyEozSicithxCVm4I+TGXXaFpVV6u3nck4p7lm+D4MiA6RtGp75ISIimTCo2InqQZKSGqMn+q4qfmvjWTzYPwzOOpa2XbzpHG5WqJGQni9tU/PyZCIikgmDip0w9caE3RdugbtL3UuUb1ao62xbG5+OnmG+DS2NiIiowThHxU405AbKteeu6PPdwTTTOyciIjIDBhU7wTsoExGRPWJQsRM1F2gjIiKyFwwqdqKwrELuEoiIiMyOQcVOODVkkgoREZGVY1CxE6YsDkdERGQrGFTsRDBXjyUiIjvEoGInHCx81c+VvFKL9k9ERKQLg4qdMHXBN1PN/OG4RfsnIiLShUGFjJKawxEVIiJqegwqZBReVERERHKQNaisWLEC0dHR8Pb2hre3N2JiYrBp0yY5SyI9bhSXo0Sl/67LREREliBrUGndujUWL16M+Ph4HD16FHfeeScmTpyI06dPy1kW6bFsR7LcJRARUTMj6+IbEyZM0Hr91ltvYcWKFTh48CCioqJkqor0ScstkbsEIiJqZqxmlTC1Wo2ff/4ZJSUliImJ0dlGpVJBpVJJrwsLC5uqPAKw41y23CUQEVEzI/tk2sTERHh6ekKpVGLatGn49ddf0bVrV51tY2Nj4ePjIz3CwsKauNrmraxCI3cJRETUzMgeVDp16oSEhAQcOnQITz/9NKZMmYIzZ87obDt//nwUFBRIj/T09CauloiIiJqS7Kd+XFxcEBkZCQDo06cPjhw5gqVLl+Lzzz+v01apVEKp5FLxREREzYXsIyq1aTQarXkoRERE1HzJOqIyf/58jB07FuHh4SgqKsLq1auxa9cu/Pnnn3KWRURERFZC1qCSnZ2NRx55BBkZGfDx8UF0dDT+/PNPjBw5Us6yqIkIIXA+qwhtAz2gdHKUuxwiIrJCsgaVr776Ss6PJxkJIfBz/BW8uPYkbu8YhK8f6y93SUREZIWsbo4KNVyPMF+DbdxdGj5y0SbAvcH71jZj9XG8uPYkAGD3hetm65eIiOwLg4od+WVaDPbPu7PeNideH4VLb49rUP8+bs4N2k+XuMQMs/VFRET2i0HFjjg5OiDU1w2rHu2nt42zowMcat0K2d/DRWdbFyftHw9XziMhIqImxqBih1r7uRls8/0TA6Tnao2o8/6ySb3w8IA2Wttua+ff+OKIiIhMwKBih9JySw22GRQZKD3X6Agq7i6O6NzSS2tbdpHp69tkF5Xp7J+IiMgYDCp2KCrUR+v17zMGY3jnYHw6ubfO9mpRN0g4KBTo28ZPa9uaI/XfsuDAxRx8tPWCNEKzN+kG+r+1Hf/+Ll6rXbGq0uAxEBERAVawhD6Zn1utK3uiQr3x1VT981Y0OoKKQlEVVkwx6YuDAIBWvm64v18YvvjrEgBg65ksrXZX8gyP+BAREQEcUbFLxWXaIxaG8oaOnAKFQmFyUKl2ObcEgHYAOpGej1/irzSoPyIiar44omKHvGtdRqwwEDh0BRIHheGAU1peiY2Jmbizc7DOK4f+SrohPZ+4fB8A4HJuKYZ1Cqq/YyIior9xRMUOuTqZ9rV2a+VdZ5sCijqXMVfbfeE60nJKsWD9acz5+QSmrjxs9Gd9vD0J//h0v0n1ERFR88URFTtUc4RkSIdAve32z7sTOcXl+PZgKo6k5tV536lWUBnYPgB3vL8LKTeqTu1Ur7Ny8kqB
OcomIiKqg0HFDtU8ZbPi4T5624X6uiHU1w0K6B45cawVVPZfzNF6XV6p0bmfrjkvREREDcGgYocUCgVSYsdJzw1x0HOmyFnfGzrklZRLz387cQ23tQswel8iIiJ9GFTslDEBpb62CgXg5Gh8H898f0x6fiXvJh75n/HzVoiIiPThZFqCnjmzdU791OfApRzDjYiIiEzEoEI6L0+ODPaEsyN/PIiISF489UMY1ikI3xy4DGdHBXbNvQNFZRUI8XaVuywiIiIGFQLu6BSM1U8OQGSwJ4K9XAEYvvsyERFRU+DYPkGhUGBg+8C/Q4q2xIWjMH9sZxmqIiIiYlAhA7xcnRHkpZS7DCIiaqYYVMggU67+aSjBVeKIiEgHBhUyyJQ1WRpKrWFQISKiuhhUyCDHJggqzClERKQLgwoZ1BTLqWjMeOpn2fYkTFi2F8WqSrP1SURE8mBQIYN0LQhnbuYMKh9svYDEqwX4/uBls/VJRETyYFAhg5omqJi/zwq17rs7ExGR7WBQIYOa4qofTqYlIiJdGFTIIAcbvTyZVzwTEdk+BhUy6GhqrsU/gyMqRESkC4MKGVRWobb4Z1gipzD7EBHZPgYVMqhzC2+Lf4YlTv2oee6HiMjmyRpUYmNj0a9fP3h5eSE4OBj33HMPzp8/L2dJpENT/MG3xGd8vD0Jlbzyh4jIpskaVHbv3o3p06fj4MGD2Lp1KyoqKjBq1CiUlJTIWRbVMrB9gMU/w1KnaS5e588SEZEtc5Lzwzdv3qz1etWqVQgODkZ8fDyGDh1ap71KpYJKpZJeFxYWWrxGaprLk89cK0QrXzez9yvA0z9ERLbMquaoFBQUAAD8/f11vh8bGwsfHx/pERYW1pTlNVuNvdfPx5N6GQwhKTeKTerzkf8dRsS8OJy8kl/1GduTEDEvDi//mqjVjtNUiIhsm9UEFY1Gg1mzZmHQoEHo1q2bzjbz589HQUGB9EhPT2/iKpsnfw+XRu1/d49QzLgzst42Id6uJvW558L1qr4/2QcA+HDrBQDA6kNpWu0Kb1aY1C8REVkXqwkq06dPx6lTp7BmzRq9bZRKJby9vbUeZHlOjg74/okBaMwZIA9l/WcZO7Xwanjn9dhxPtsi/RIRUdOwiqAyY8YM/PHHH9i5cydat24tdzmkw6DIQCS9Na7B+7sYuAWzpU7R+Lg5W6ZjIiJqErIGFSEEZsyYgV9//RU7duxA27Zt5SyHDKg5qfblcZ3xzWP98eNTtxm1b69wX63XHi6OuLdXK+m1vrsnazQCR1JzUaKqNL1gAL3C/Bq0HxERWQdZr/qZPn06Vq9ejQ0bNsDLywuZmZkAAB8fH7i5mf8KEGq89dMH4a8L1/H44Hb1Xg3k4uiA8hprmIR4u2L/vDuRW1KOn4+m44kh7RDm7479F28gq1Cld0Rl5f5UvPnHGfRp44dfnh4ItUbU+dyl25L01qGqtPyqukREZDmyBpUVK1YAAIYNG6a1feXKlZg6dWrTF0QG9QzzRc8wX73vL5zQFaOiWmDlvhR88VeK1nuhvm4I9XVDt1Y+0jYFqkKHvqDyzYFUAED85TzsOJeFad8ew7v/jNZq89G2C3rraYrl/4mIyHJkDSqWWDad5PP6hK6YOqjq9N2MOzrg5JUCrdM7ulRf+axrvZP/7U3B5ZxS6fVjq44CAGb9mGB0TZW84Q8RkU2TNaiQfTjx+iiUqCoRWmOtFB93Z/z47xiD+zr8nVSq80RmQRkyCm7CQ+mEN/440+jajqbm4a7o0Eb3Q0RE8mBQoUbzcXNu8NU11fNN1H8nldtitwMAHh0UYZbaVu1PxcK7o8zSFxERNT2ruDyZmq/aQaXaznNc/4SIiBhUSGYpN6puGrjtbJbW9uIGXo5MRET2hUGFrMJ/91xCUdmt5e5vFJebpV8u+EZEZNsYVMhqPPzVYbP3+c8+XOmYiMiWMaiQ1TiRnm/2PmvPfSEiItvCoEJ2Td/S/EREZBsYVMiucUSFiMi2
MaiQXeOIChGRbWNQIbtmrhGVSrUG+y/ewM1y3juIiKgpMaiQXatxA+dGWbo9CQ99cQjTVx8zT4dERGQUBhWya+Y49SOEwKp9qQCAHVwxl4ioSTGokF1rbFDJLSnHwMU7UMSVcomIZMGgQnatsXNUVu1PRUZBmZmqISIiUzGokF2rPaKSXVSG6d8fw/6LN4zaX60x0yQXIiJqEAYVktWCu7patP/aIyovrT2JuMQMPPTFIb1tatp57rrFaiMiIsMYVEhWwzsHm7zPc8M7IHXxeIyJamGw7eWcUq3XO89rB49r+TfR640tePOPM3X2VVWqUVbBy5GJiOTEoEKyigj0wGcP9zFpH2dHBQBgxcO90dLHtd625zKL6n1/xa6LKCyrxFd7U+q891bcWVy6UWJSbUREZF4MKiS7ri29TWp/V3QoAEChUOBOPSMyfdr4GeynvFKDbw9e1tpWoqrE9rNZUFWq8c2By3r2JCKipuIkdwFEzk4Kk9p7ud76sZ03tjO+P5Sm9f7mWUNw6XoJ4i/noV+E/sCyan/dUZRp38Xjr6QbGNdd/2kljUbAwcG0momIqGE4okKyc3Iw7cfQQXErJHi5OmPGHZEAgC8e6YvUxePRuYW31Ka+q5MvZBXX2fZXUtXVQBsTM/Xu9/3hNL3vERGReTGokOycTBydUNRqPmd0JyQuHIWRXUOkbdVd1rw8+UaxqsE11rRg/Smz9ENERIYxqJDsnBxNCyquzo51tnm5Omu9dnSoO6Iy5+cTWm3Wxl/Req0xYXE4XVcJERGR+TGokOxqh4xpt7fX27ZrS2+dQaW26lM/osaIypGU3Hr3+eGI8ad0vtqbotU3ERFZBoMKWZWFE7ripTGd9L6/8bkhRvWj0HHqx82l/oDzyq+mndKpbOTy/EREZBiDClkVR0cHKGpPQmkAaTLt3yvgl1WocaO4vNH91lSh5vL6RESWxqBCdunWVT9Vox4/HU03+2dUVHJEhYjI0hhUyCrc26sVAjxccHePqsXcPp7UCx1DPDGpfxjaBLib3F/1VT/VK9MWlVWardZq5RxRISKyOC74Rlbhowd6Qq0R0tU6d/cIlULLwUs5ePC/B7UuPzak5vSRU1cLdC6R31iVvLMyEZHFyTqismfPHkyYMAGhoaFQKBRYv369nOWQzBz1rKdyW7sAHH11BD434Z5ABTcrpOd3LduL3BLzzk8BqpbgJyIiy5I1qJSUlKBHjx5Yvny5nGWQDQj0VJq0bP31ojILVlNFreeqn1d+TcT9nx1AJU8NERE1mqynfsaOHYuxY8ca3V6lUkGlurW6aGFhoSXKIjugb3TGnPQFlep7D+2/mIOhHYN0tql9v6CyCjV+P3ENt3cKQrBX/XeEJiJqTmxqjkpsbCwWLVokdxlkAzyUlv/RVhtY8E1fkCm4WYExS/ZgWKdgvDSmEyZ9cQhnM6pCd7i/O/a8eIfZayUislU2ddXP/PnzUVBQID3S081/ySnZh+o//JakL4gYen/N4TRkFJThh8NpWLYjWavWtNxSs9ZIRGTrbCqoKJVKeHt7az2IdHl6WKTFP0PXRT81l9WvfS8hAJi/LhGxm85Jry1xNRIRkT2xqaBCZCx/Dxez9tch2BMPDQjX2qbr1E9SdrH0fPPpzDrv/3DY+PsJERERgwqRTv0j/LVeJ2UXo/b8XLWOIRVVhe4rfYrKKrD++FWz1UdE1FzIOpm2uLgYycnJ0uuUlBQkJCTA398f4eHh9exJZFlfP9Yf+TfLERO7Q9qmgHZS0XX1cc3Vau/v21p63n3hFvMXSUTUDMg6onL06FH06tULvXr1AgDMnj0bvXr1wmuvvSZnWURQOjmgpY+b1ra6IyoCao3Ayn0pOHOtakKsqlItvZ9RUAYN77BMRNQoso6oDBs2TGvyIVFT+HPWUCzdfgEbE+vOIQGAs2+MqbO43OsTuuJyjvYVOR9sOY+JPUOx6PczAIDUxeOhdLqV/f9KuoEfjqTBzdnRzEdA
RNR82NQ6KkSN1drPDZ1aeOG+3q31BhU3l1vB4uLb45CaU4L2QZ54fcMprXZHL+ch8WqB1rbfT2RovX7lV+19jCGEgEJh+QXriIhsASfTkt1yqjEq8vuMwZh2e3v88ORtAKqW5NclurWP1mtHBwXaB3kCAAZFBtZpr6pxv59KtQar9qc2tmz8eTqr0X0QEdkLjqiQ3fppWgz+8el+AFUjKfPGdpbei27tg0n9w9Eu0ANjurWAEMCWM5l4fHBbvf2N7BqClY/2w7Rv47UCSrUVuy6ape6fj6ZjTLcWZumLiMjWMaiQ3eod7oeZwzsAQsCv1roqCoUCsf/orrXtiSHt6u1PoVDgjk7BCPRU4mr+zTrvf7D1QuOLRtU9goiIqAqDCtm12SM7mr1PQ0vnN5azI+enEBFV4xwVIhMZuhlhY7k48SohIqJqDCpEJsopVlm0/xvFKvx4RP9S+8fT8pBZUAYAei/vr1BrkJ5bime+j8fxtDyL1ElE1BR46ofIRE2xhttLvyTigX7h0GgE9iRdR3RrX/h7uODU1QLc+/cE4ccHt8VXe1Ow44XbEeSlxPaz2RjRNQRbTmdi9k8npL42JmZiyQM90b21j3QFExGRrWBQIbJSpeWV2JBwDfPXJaKVrxv2zbsTh1Nypfer77x85we7cXePUPx24hrGRLXQeTPEWT8mAKhalI6IyJbw1A+RFRjeObjOttyScmxMrFpATtdVRjX9duIaAN13bCYismUMKkQy6xHmiy+n9MXzI7SvUHJQKPBX0g2zftb1IsvOryEiMjcGFSKZffxgTygUCvh5OGttH7h4h9br/NJyNHZl/bIKteFGRERWhEGFyETrnhmodwn+aq183eps++iBHvj8X32k1+f/MwYnF45CmwAPAFXL9den5xtbpRsgNlRhWUWj9o/ddBYR8+JwNqOwUf0QERmLk2mJTNQ73A9HXhmOtvM3Stt6tPbBiSu3blD4wqiOiG7tA0cHBzz+9RE82C8M9/ZqDQB4fkRHuLs4QulU9ah2o6jc4rU7Ozbs/01yS8rh7eqEz3dfAgCMXfoXJ+YSUZNgUCFqgNp3N175aH/svpCNxZvOIatQhaEdg6RRlx0vDNNq+9yIDjr7/GibeZbgr8+xy3noGOJldHshhFYgq2nhb6cxc3gH+Ne6PQERkTkxqBCZgb+HC+7t1RoTokNRWqGGt6uz4Z1kcPF6sUntj6TqXyxu1f5UHE7JxcbnhjS2LCIivThHhaiR2gV5SM+dHB0aHFKeHtbeXCXp1b9tQJ1txapK/H7iGhLS8xExLw4R8+Iw5+eqBeNW7U+pt78zGYWIfHkjNI1cBe9a/k0sWH/K5CBFRPaPIypEDbT1+aH4/lAanjFTwIgK9TZLP/WpUGvqbBv90Z4667Ssjb+CtfFXjOqzUiOw60I27uwc0uC6nv8xAYdScrEh4SpOLhzd4H6IyP5wRIWogTqEeGHh3VEI9nY1S39uzua5GeEbE6Ow44XbpdezR3aUrkJ65vtjOHNN+4odQ4vJGeOxVUdN3qdSrcGOc1koKK3AiSv5AIDCsspG10JE9oUjKkRW4k4dq9Ma689ZQzF6yR60CXDHIzERAICLb4+Dg6Jq4u+HW29N1B338V9IiR0HhUIh67oq7205j893X0LnFl4oq6g70kNEBDCoEFkNhUKB1MXjETEvzqj2xxeMxH//uoT7erdGZLBnncuF61uXZdDiHbj29x2Ym9qYJXuQXaRCbknV5djnMou03v9izyWcuJKPZZN61bm6ioiaHwYVIhuQ/NZY/CfuLFbtT5W2+Xm44KUxnRvUnyVDyr++OoQjqbkI9XXDorujMKRDkPReVmFZnWBS21sbzwIA/jiZgd7hvlg7bSAcDCyGZ6xDl3LwwH8PAgC+fqw/bu8YZGAPIpIbgwqRldk1Zxh+OXYFjw5qi9iNZxHkpYSTowMW3h2F1YfSUK7WoE8bP7nLrOPgpRz4ubtI9ye6dL0E//rqMACg
bxs/VKg1WoviGeNYWj52J13HHZ1MPy2m0Qj8fvIahnepmuSrdHKQQgoATPnf4QYvWlep1sBBoZAC1Kp9KYjddA7xC0bCU8n/rBKZk0II0bjrCmVUWFgIHx8fFBQUwNvb8ldMEMntbEYhVu1LxXMjOiBUxzL9+rSbH4dGXkFcx7yxnXEkJRfbz2Wbt+Na/je1r8ErijQagR5vbMHoqBYoKquAEMCWM1kG+25IUCmv1GDMkj0I8HTBz9MGAoDW6brkt8bCqYErABM1F6b8/Wb0J7IhXVp6451/Rpu8354X78Dgd3Y2+vNXPzEAAyMDUaKqhIfSCc6ODhYPKjVvMwAAR1Nz4eXqjE4tbq2w++/v4lFUVmn0JdXGKq/UoFyt0RolSc4uxqUbJbh0owSFZRWIXrhFa5/7VuzHhhmDzVbDM9/H42J2CTbPGsI5O9QsMagQNQNBXvXfRPHTyb0xrntLVKg1uJBVhPEf79V6/+NJvRDg4YKBkYEAAI+//3D3DPO1SL01FZVV4s73d+HB/mG4vWMw/vnZAQDA/nl3ItTXDUu2XcBWI0ZPdBFC1PnjX1BagQqNBoGeSnR8dRMA4PSi0fBQOuH9P89rhaHaIQWAyae36pNXUo6NiZkAgAMXc6R/f6LmhKd+iJqJ2T8mYH3CVXQI9sL5rFsTWn3dnZHw2iittuuPX0Worxv6RfihrEIDNxfda7yUVajRecFmi9ZtSacWjZZGSxLS81Gh1mDat/HIKSnH0VdHoO9/tgEAFIqqsNL1tT+N6nfqwAhMHRiBg5dyEBHogSAvJSICPHAt/ybC/N3rtBdCoEhVKa1q/MmOJKTn3sSPR9O12v05a6jWSFK1sgo1dp2/jsKyCixYfwqrHu2PmPZ1VyH++Wg65q49iXXPDMQ/Pt0PANg37044KICWPsafSjSGrhBoigMXczB37Qn89O8Yk05zmvoZAgID2zMANjVT/n4zqBA1Q1fzb2LQ4h0Abo1MNNT8dSfxw+F0ww31mDwgHADwn3u6If5yHjqEeOHS9WLc+/cfUns0tGMQFt0dhbaBVbdfqJ7j8sq4LtJVT/rUnlfzx8lrmLH6eJ12E3uG4skh7bD1TBZuFKswoksIHl11pN6+tzw/tN6bVhaUVsDHXfctImb+cBy/nbiGZZN64dkfbtXT0AnL1f8mrs4OOPfmWABVYXvd8av4/F99MDqqhcl9Hk7JhZerEzqGeGHHuWw8+U3VQoVn3hgNIaqC3MLfz2Dpgz0xsWerBtXdVIQQuFmhhruLE26Wq/X+z4QuSVlFeOrbeNzXuxVm3HnrJqk3y9VQOjmg4GYFkq8Xo1+EvyVKB8CgQkRGSLlRAiEE2gV5Nrqvjq9sQrmO5fn1OffmGFzOKYWnq5O0am5txq4nU58wfzf8PmMwvj1wGZHBnnj6+2ON7tOchnYMwpSYNnj8a9NX9v3koV7o0doXQ95t/NyjmnqF++Ku6FBEBLijY4gXUnNKMKBtAB5ddRj7knMQEeCOnXOGQaFQYPGmc/hs98V6+2sX6IEdc4ZJr6vvyD2qawj++0hfvLv5HL7cm4IdL9yO1n5Vo00bEzPwTI3v6uwbY5BXWo6Bf4drwPgAlJCejy/2XMLZzEJcul6is82H9/fA7J9OaG2r7r/2FV4N8dqGU/jmwGVsem4IurTU/lul1gh8sOU8hnQI0jkKVt0mLbcUfu7OKKvQoIWPK2JityOjxjIDE3uGYumDveqt42hqLg6n5uLdzeelbdXHuf/iDTz0xSEM6RAoXbkHWG5yOIMKETWpsxmFmPzlIWkRt6S3xsJRoUB6Xim+O3gZX/yVgo8n9cKpqwXoFOKF+/q0NthnY4NKSx9X7Jo7TGsybvV/7trO39iovnXpGOKJvNIKXC9Smb1vW/fxpF6Y+UPdUZ/aurfyQeJV4+b47Jt3J05fLcBT38ZjYPsA9I3wx5bTmQbX6THWq+O74D9xt0a3ljzQE7N+TAAADOkQiP/c0w2Bnkp8
8dclbD2Thflju2Bwh0Ck3ijBsPd3SftFhXrjdI3bVnz2cB+M6dYCv8RfwTubzyEy2BP7L+YAqAobydnFyCpUYeNzgxHs5Yq9STfw8FeHtGob260FNp3KrFPzsE5BeGZYJEJ9XVFUVqkViqoDYm2BnkrcKK7/Z/alMZ3NftNUmwsqy5cvx3vvvYfMzEz06NEDy5YtQ//+/Q3ux6BCZF2yCsvg4+YMVzPct8jY+S/39mqFD+/vAYVCgWv5N7E+4Soe6h8OX3cXvft8uPUCPt6e1KC63rmvO/6vTxj+Sr6Bz3dfREJ6Pt7/vx4Y170lPt6epHW7AnNZcFdXvPnHGbP3S2Sshp7C08emLk/+8ccfMXv2bHz22WcYMGAAlixZgtGjR+P8+fMIDm74vU+IqOmFmOkGjQDg6uyI8/8Zg7IKDW5/byfySysAACsm98bY7i117hPq64ZnhkUa7Hv2yI4I9lLicEoubhSrMCgyEBeyirAh4ZrUpmtLb3z7eH98vucSrubdxJv3dIO/x63wc3vHoDor27YJqDtRtqF83Z3xzn3RuLNzMBQAgwo1W7KPqAwYMAD9+vXDJ598AgDQaDQICwvDs88+i3nz5tW7L0dUiMjcisoq4OWqe8KoMT7/e87G1EERUDo5Iquwah5BoKcSFWoNlE4OuJJXdfVPWYUaWYVlWBt/BX0j/NExxFPv1TdPfxevc7gfAOaO7oTySg2m3xEJZ0cFjqXl4eL1Eozv3hLuLo5QKBQ4eSUfbQI84O7iCOcacw6KVZXo9vqf8HV3lsJgtYHtA3At/yZSc0p1fu6oriF4+LY2aBfkAScHB3i7OaG0XI2AvwNdY0+xLXmgJ3JLyvFGjZCmq05j/aN3K6w7dhUAMHN4B/i6OWPNkTRcyCoGAFz4z1iM//gvJGUXN6pue9I73Bfrnhlk9n5t5tRPeXk53N3dsXbtWtxzzz3S9ilTpiA/Px8bNmzQaq9SqaBS3TqXVlhYiLCwMAYVImoWhBC4kncTXq5O8HFzttgCcKXllXB0UGjN71FVqqHWCLi7GD8QfzmnBMfS8vD8jycQ5u+Gv168E6Xllfh4ezL6t/VDfmkFxnVviRJVJbzdnPHT0XTc1T0UPd7Ygi4tvbHpuSE6+1VrBNq/XBWCIoM9EdMuAMfT83DqaiHe+2c0Dl7Kxd09QzG0Q6DJ/0ZCCHy1NwVLtyfB0UGBNyd2Q98IP5zLKMKgyECsPnQZbQI94O7siB+PpOONe7rBw8URV/JuIi4xA4s3ncNrd3XFY4PbAqiawNrKzw0tfdy05l11CqlaJqBjiCd+fCoGl3NLkZRVhI4hXtibfAMbEq7iicHtcDm3BEGeSly6UYL7+4ahWyufOjWfuVaIcR//hR5hvhjaIRDX8svQPthDmjTbvZUPIoM9cS3/JmYO74D80gqE+7uje2sf3ChWwUGhkEYLNRoBhQLIL62Au9KxzoKL5mIzQeXatWto1aoV9u/fj5iYGGn7iy++iN27d+PQIe0JRAsXLsSiRYvq9MOgQkREZDtMCSo2dUOK+fPno6CgQHqkpzd87QYiIiKyfrJOpg0MDISjoyOysrSXv87KykKLFnUX81EqlVAq618KnIiIiOyHrCMqLi4u6NOnD7Zv3y5t02g02L59u9apICIiImqeZL88efbs2ZgyZQr69u2L/v37Y8mSJSgpKcGjjz4qd2lEREQkM9mDygMPPIDr16/jtddeQ2ZmJnr27InNmzcjJCRE7tKIiIhIZrKvo9IYXEeFiIjI9tjtVT9ERETUvDCoEBERkdViUCEiIiKrxaBCREREVotBhYiIiKwWgwoRERFZLQYVIiIisloMKkRERGS1ZF+ZtjGq16orLCyUuRIiIiIyVvXfbWPWnLXpoFJUVAQACAsLk7kSIiIiMlVRURF8fHzqbWPTS+hrNBpcu3YNXl5eUCgUZu27sLAQYWFhSE9Pt8vl+Xl8ts/ej5HHZ9vs/fgA+z9GSx6fEAJFRUUI
DQ2Fg0P9s1BsekTFwcEBrVu3tuhneHt72+UPYDUen+2z92Pk8dk2ez8+wP6P0VLHZ2gkpRon0xIREZHVYlAhIiIiq8WgoodSqcTrr78OpVIpdykWweOzffZ+jDw+22bvxwfY/zFay/HZ9GRaIiIism8cUSEiIiKrxaBCREREVotBhYiIiKwWgwoRERFZLQYVHZYvX46IiAi4urpiwIABOHz4sNwl6bRw4UIoFAqtR+fOnaX3y8rKMH36dAQEBMDT0xP33XcfsrKytPpIS0vD+PHj4e7ujuDgYMydOxeVlZVabXbt2oXevXtDqVQiMjISq1atssjx7NmzBxMmTEBoaCgUCgXWr1+v9b4QAq+99hpatmwJNzc3jBgxAklJSVptcnNzMXnyZHh7e8PX1xePP/44iouLtdqcPHkSQ4YMgaurK8LCwvDuu+/WqeXnn39G586d4erqiu7du2Pjxo0WP76pU6fW+T7HjBljM8cXGxuLfv36wcvLC8HBwbjnnntw/vx5rTZN+TNp7t9jY45v2LBhdb7DadOm2cTxAcCKFSsQHR0tLfAVExODTZs2Se/b8vdnzPHZ+vdX2+LFi6FQKDBr1ixpm01+h4K0rFmzRri4uIj//e9/4vTp0+LJJ58Uvr6+IisrS+7S6nj99ddFVFSUyMjIkB7Xr1+X3p82bZoICwsT27dvF0ePHhW33XabGDhwoPR+ZWWl6NatmxgxYoQ4fvy42LhxowgMDBTz58+X2ly6dEm4u7uL2bNnizNnzohly5YJR0dHsXnzZrMfz8aNG8Urr7wi1q1bJwCIX3/9Vev9xYsXCx8fH7F+/Xpx4sQJcffdd4u2bduKmzdvSm3GjBkjevToIQ4ePCj++usvERkZKSZNmiS9X1BQIEJCQsTkyZPFqVOnxA8//CDc3NzE559/LrXZt2+fcHR0FO+++644c+aMePXVV4Wzs7NITEy06PFNmTJFjBkzRuv7zM3N1Wpjzcc3evRosXLlSnHq1CmRkJAgxo0bJ8LDw0VxcbHUpql+Ji3xe2zM8d1+++3iySef1PoOCwoKbOL4hBDit99+E3FxceLChQvi/Pnz4uWXXxbOzs7i1KlTQgjb/v6MOT5b//5qOnz4sIiIiBDR0dHiueeek7bb4nfIoFJL//79xfTp06XXarVahIaGitjYWBmr0u31118XPXr00Plefn6+cHZ2Fj///LO07ezZswKAOHDggBCi6g+ng4ODyMzMlNqsWLFCeHt7C5VKJYQQ4sUXXxRRUVFafT/wwANi9OjRZj4abbX/kGs0GtGiRQvx3nvvSdvy8/OFUqkUP/zwgxBCiDNnzggA4siRI1KbTZs2CYVCIa5evSqEEOLTTz8Vfn5+0vEJIcRLL70kOnXqJL2+//77xfjx47XqGTBggPj3v/9tseMToiqoTJw4Ue8+tnR8QgiRnZ0tAIjdu3cLIZr2Z7Ipfo9rH58QVX/oav5RqM2Wjq+an5+f+PLLL+3u+6t9fELYz/dXVFQkOnToILZu3ap1TLb6HfLUTw3l5eWIj4/HiBEjpG0ODg4YMWIEDhw4IGNl+iUlJSE0NBTt2rXD5MmTkZaWBgCIj49HRUWF1rF07twZ4eHh0rEcOHAA3bt3R0hIiNRm9OjRKCwsxOnTp6U2NfuobtPU/x4pKSnIzMzUqsXHxwcDBgzQOh5fX1/07dtXajNixAg4ODjg0KFDUpuhQ4fCxcVFajN69GicP38eeXl5Uhu5jnnXrl0IDg5Gp06d8PTTTyMnJ0d6z9aOr6CgAADg7+8PoOl+Jpvq97j28VX7/vvvERgYiG7dumH+/PkoLS2V3rOl41Or1VizZg1KSkoQExNjd99f7eOrZg/f3/Tp0zF+/Pg6ddjqd2jTNyU0txs3bkCtVmt9QQAQEhKCc+fOyVSVfgMGDMCqVavQqVMnZGRkYNGiRRgyZAhOnTqFzMxMuLi4wNfXV2ufkJAQZGZmAgAyMzN1Hmv1e/W1KSwsxM2b
N+Hm5maho9NWXY+uWmrWGhwcrPW+k5MT/P39tdq0bdu2Th/V7/n5+ek95uo+LGXMmDH4xz/+gbZt2+LixYt4+eWXMXbsWBw4cACOjo42dXwajQazZs3CoEGD0K1bN+nzm+JnMi8vz+K/x7qODwAeeughtGnTBqGhoTh58iReeuklnD9/HuvWrbOZ40tMTERMTAzKysrg6emJX3/9FV27dkVCQoJdfH/6jg+wj+9vzZo1OHbsGI4cOVLnPVv9HWRQsWFjx46VnkdHR2PAgAFo06YNfvrppyYLEGQ+Dz74oPS8e/fuiI6ORvv27bFr1y4MHz5cxspMN336dJw6dQp79+6VuxSL0Hd8Tz31lPS8e/fuaNmyJYYPH46LFy+iffv2TV1mg3Tq1AkJCQkoKCjA2rVrMWXKFOzevVvussxG3/F17drV5r+/9PR0PPfcc9i6dStcXV3lLsdseOqnhsDAQDg6OtaZAZ2VlYUWLVrIVJXxfH190bFjRyQnJ6NFixYoLy9Hfn6+Vpuax9KiRQudx1r9Xn1tvL29mzQMVddT33fTokULZGdna71fWVmJ3NxcsxxzU/8MtGvXDoGBgUhOTpbqsoXjmzFjBv744w/s3LkTrVu3lrY31c+kpX+P9R2fLgMGDAAAre/Q2o/PxcUFkZGR6NOnD2JjY9GjRw8sXbrUbr4/fceni619f/Hx8cjOzkbv3r3h5OQEJycn7N69Gx9//DGcnJwQEhJik98hg0oNLi4u6NOnD7Zv3y5t02g02L59u9Y5TGtVXFyMixcvomXLlujTpw+cnZ21juX8+fNIS0uTjiUmJgaJiYlaf/y2bt0Kb29vaSg0JiZGq4/qNk3979G2bVu0aNFCq5bCwkIcOnRI63jy8/MRHx8vtdmxYwc0Go30H5yYmBjs2bMHFRUVUputW7eiU6dO8PPzk9pYwzFfuXIFOTk5aNmypVSXNR+fEAIzZszAr7/+ih07dtQ5BdVUP5OW+j02dHy6JCQkAIDWd2itx6ePRqOBSqWy+e/P0PHpYmvf3/Dhw5GYmIiEhATp0bdvX0yePFl6bpPfocnTb+3cmjVrhFKpFKtWrRJnzpwRTz31lPD19dWaAW0tXnjhBbFr1y6RkpIi9u3bJ0aMGCECAwNFdna2EKLqMrTw8HCxY8cOcfToURETEyNiYmKk/asvQxs1apRISEgQmzdvFkFBQTovQ5s7d644e/asWL58ucUuTy4qKhLHjx8Xx48fFwDEhx9+KI4fPy4uX74shKi6PNnX11ds2LBBnDx5UkycOFHn5cm9evUShw4dEnv37hUdOnTQunw3Pz9fhISEiH/961/i1KlTYs2aNcLd3b3O5btOTk7i/fffF2fPnhWvv/66WS7fre/4ioqKxJw5c8SBAwdESkqK2LZtm+jdu7fo0KGDKCsrs4nje/rpp4WPj4/YtWuX1uWdpaWlUpum+pm0xO+xoeNLTk4Wb7zxhjh69KhISUkRGzZsEO3atRNDhw61ieMTQoh58+aJ3bt3i5SUFHHy5Ekxb948oVAoxJYtW4QQtv39GTo+e/j+dKl9JZMtfocMKjosW7ZMhIeHCxcXF9G/f39x8OBBuUvS6YEHHhAtW7YULi4uolWrVuKBBx4QycnJ0vs3b94UzzzzjPDz8xPu7u7i3nvvFRkZGVp9pKamirFjxwo3NzcRGBgoXnjhBVFRUaHVZufOnaJnz57CxcVFtGvXTqxcudIix7Nz504BoM5jypQpQoiqS5QXLFggQkJChFKpFMOHDxfnz5/X6iMnJ0dMmjRJeHp6Cm9vb/Hoo4+KoqIirTYnTpwQgwcPFkqlUrRq1UosXry4Ti0//fST6Nixo3BxcRFRUVEiLi7OosdXWloqRo0aJYKCgoSzs7No06aNePLJJ+v8Ulvz8ek6NgBaPy9N+TNp7t9jQ8eXlpYmhg4dKvz9/YVSqRSRkZFi7ty5WutwWPPxCSHEY489Jtq0aSNcXFxEUFCQGD58uBRS
hLDt78/Q8dnD96dL7aBii9+hQgghTB+HISIiIrI8zlEhIiIiq8WgQkRERFaLQYWIiIisFoMKERERWS0GFSIiIrJaDCpERERktRhUiIiIyGoxqBAREZHVYlAhokaJiIjAkiVLjG6/a9cuKBSKOjdGIyLShSvTEjUzw4YNQ8+ePU0KF/W5fv06PDw84O7ublT78vJy5ObmIiQkBAqFwiw1mGrXrl244447kJeXB19fX1lqICLjOMldABFZHyEE1Go1nJwM/yciKCjIpL5dXFwafTt7Imo+eOqHqBmZOnUqdu/ejaVLl0KhUEChUCA1NVU6HbNp0yb06dMHSqUSe/fuxcWLFzFx4kSEhITA09MT/fr1w7Zt27T6rH3qR6FQ4Msvv8S9994Ld3d3dOjQAb/99pv0fu1TP6tWrYKvry/+/PNPdOnSBZ6enhgzZgwyMjKkfSorKzFz5kz4+voiICAAL730EqZMmYJ77rlH77FevnwZEyZMgJ+fHzw8PBAVFYWNGzciNTUVd9xxBwDAz88PCoUCU6dOBVB1K/rY2Fi0bdsWbm5u6NGjB9auXVun9ri4OERHR8PV1RW33XYbTp061cBvhIgMYVAhakaWLl2KmJgYPPnkk8jIyEBGRgbCwsKk9+fNm4fFixfj7NmziI6ORnFxMcaNG4ft27fj+PHjGDNmDCZMmIC0tLR6P2fRokW4//77cfLkSYwbNw6TJ09Gbm6u3valpaV4//338e2332LPnj1IS0vDnDlzpPffeecdfP/991i5ciX27duHwsJCrF+/vt4apk+fDpVKhT179iAxMRHvvPMOPD09ERYWhl9++QUAcP78eWRkZGDp0qUAgNjYWHzzzTf47LPPcPr0aTz//PN4+OGHsXv3bq2+586diw8++ABHjhxBUFAQJkyYgIqKinrrIaIGatA9l4nIZtW+7bsQVbdsByDWr19vcP+oqCixbNky6XWbNm3ERx99JL0GIF599VXpdXFxsQAgNm3apPVZeXl5QgghVq5cKQCI5ORkaZ/ly5eLkJAQ6XVISIh47733pNeVlZUiPDxcTJw4UW+d3bt3FwsXLtT5Xu0ahBCirKxMuLu7i/3792u1ffzxx8WkSZO09luzZo30fk5OjnBzcxM//vij3lqIqOE4R4WIJH379tV6XVxcjIULFyIuLg4ZGRmorKzEzZs3DY6oREdHS889PDzg7e2N7Oxsve3d3d3Rvn176XXLli2l9gUFBcjKykL//v2l9x0dHdGnTx9oNBq9fc6cORNPP/00tmzZghEjRuC+++7Tqqu25ORklJaWYuTIkVrby8vL0atXL61tMTEx0nN/f3906tQJZ8+e1ds3ETUcgwoRSTw8PLRez5kzB1u3bsX777+PyMhIuLm54Z///CfKy8vr7cfZ2VnrtUKhqDdU6GovGnlB4hNPPIHRo0cjLi4OW7ZsQWxsLD744AM8++yzOtsXFxcDAOLi4tCqVSut95RKZaNqIaKG4xwVombGxcUFarXaqLb79u3D1KlTce+996J79+5o0aIFUlNTLVtgLT4+PggJCcGRI0ekbWq1GseOHTO4b1hYGKZNm4Z169bhhRdewBdffAGg6t+gup9qXbt2hVKpRFpaGiIjI7UeNefxAMDBgwel53l5ebhw4QK6dOnSqOMkIt04okLUzERERODQoUNITU2Fp6cn/P399bbt0KED1q1bhwkTJkChUGDBggX1joxYyrPPPovY2FhERkaic+fOWLZsGfLy8updh2XWrFkYO3YsOnbsiLy8POzcuVMKE23atIFCocAff/yBcePGwc3NDV5eXpgzZw6ef/55aDQaDB48GAUFBdi3bx+8vb0xZcoUqe833ngDAQEBCAkJwSuvvILAwMB6r0AioobjiApRMzNnzhw4Ojqia9euCAoKqne+yYcffgg/Pz8MHDgQEyZMwOjRo9G7d+8mrLbKSy+9hEmTJuGRRx5BTEwMPD09MXr0aLi6uurdR61WY/r0
6ejSpQvGjBmDjh074tNPPwUAtGrVCosWLcK8efMQEhKCGTNmAADefPNNLFiwALGxsdJ+cXFxaNu2rVbfixcvxnPPPYc+ffogMzMTv//+uzRKQ0TmxZVpicjmaDQadOnSBffffz/efPPNJvtcrmhL1PR46oeIrN7ly5exZcsW3H777VCpVPjkk0+QkpKChx56SO7SiMjCeOqHiKyeg4MDVq1ahX79+mHQoEFITEzEtm3bOIGVqBngqR8iIiKyWhxRISIiIqvFoEJERERWi0GFiIiIrBaDChEREVktBhUiIiKyWgwqREREZLUYVIiIiMhqMagQERGR1fp/pg5h03eWrXUAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# Train the sequence-to-sequence machine-translation model.\n",
    "# train_data: list of (input_ids, target_ids) pairs; shuffled in place\n",
    "# each epoch. Trains with teacher forcing and plots a smoothed loss curve.\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\\\n",
    "        learning_rate=1e-3):\n",
    "    # Prepare one optimizer per sub-model and the loss function\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    # Fix: iterate over the `epochs` parameter instead of the global\n",
    "    # n_epochs, so the function honours its own argument\n",
    "    with trange(epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for step, data in enumerate(train_data):\n",
    "                # Convert source/target sequences to 1 * seq_len tensors.\n",
    "                # For simplicity the batch size is 1; with larger batches\n",
    "                # the encoder would need padding and would have to return\n",
    "                # the hidden state of the last non-padding token, and\n",
    "                # decoding would need matching changes.\n",
    "                input_ids, target_ids = data\n",
    "                input_tensor, target_tensor = \\\n",
    "                    torch.tensor(input_ids).unsqueeze(0),\\\n",
    "                    torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # Feed the ground-truth target sequence for\n",
    "                # teacher-forcing training\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\\\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                pbar.set_description(f'epoch-{epoch}, '+\\\n",
    "                    f'loss={loss.item():.4f}')\n",
    "                step_losses.append(loss.item())\n",
    "                # With batch size 1 the per-step loss is very noisy;\n",
    "                # averaging the most recent steps gives a smoother,\n",
    "                # easier-to-read training curve\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "    \n",
    "hidden_size = 128\n",
    "n_epochs = 20\n",
    "learning_rate = 1e-3\n",
    "\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder, n_epochs, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be47e115",
   "metadata": {},
   "source": [
    "下面实现贪心搜索解码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "678192b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 考 研 英 语 写 作 考 前 冲 刺 点 题\n",
      "target： studying the english writing pre-test puncture question .\n",
      "pred： studying the english writing pre-test puncture question .\n",
      "\n",
      "input： p y t h o n 数 据 可 视 化 实 战\n",
      "target： python data visualization combat\n",
      "pred： python data visualization combat data visualization combat\n",
      "\n",
      "input： 路 由 与 交 换 技 术\n",
      "target： routes and exchange technologies\n",
      "pred： routes and exchange technologies\n",
      "\n",
      "input： 3 d s m a x / v r a y 室 内 家 装 工 装 效 果 图 表 现 技 法 （ 微 课 版 ）\n",
      "target： 3ds max/vray domestic bunker effects performance techniques (micro-pedagogical version)\n",
      "pred： 3ds max/vray domestic bunker effects performance techniques (micro-pedagogical version)\n",
      "\n",
      "input： 计 量 经 济 学 理 论 与 应 用 — — 基 于 e v i e w s 的 应 用 分 析\n",
      "target： theory and application of econometrics — application analysis based on eviews\n",
      "pred： theory and application of econometrics — application analysis based on eviews\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code adapted from the GitHub project pytorch/tutorials\n",
    "(Copyright (c) 2023, PyTorch, BSD-3-Clause License (see appendix))\n",
    "\"\"\"\n",
    "# Greedily decode `sentence`: at every step keep only the single most\n",
    "# probable token. Returns the decoded target sentence (string) and the\n",
    "# decoder attention weights.\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "        \n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        decoder_outputs, decoder_hidden, decoder_attn = \\\n",
    "            decoder(encoder_outputs, encoder_hidden)\n",
    "        \n",
    "        # Take the most probable token at each step\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        \n",
    "        decoded_ids = []\n",
    "        # view(-1) (instead of squeeze()) keeps a 1-D tensor even when\n",
    "        # the decoded sequence has length 1, so iteration always works\n",
    "        for idx in topi.view(-1):\n",
    "            # Stop at the first end-of-sequence token\n",
    "            if idx.item() == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(idx.item())\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "            \n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence, _ = greedy_decode(encoder, decoder, pair[0],\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38320f0",
   "metadata": {},
   "source": [
    "\n",
    "接下来使用束搜索解码来验证模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3496efa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 短 视 频 运 营 ： 从 入 门 到 精 通 （ 微 课 版 ）\n",
      "target： short video operation: from entry to mastery (microtext)\n",
      "pred： short video operation: from entry to mastery (microtext)\n",
      "\n",
      "input： 从 零 开 始 ： p h o t o s h o p 工 具 详 解 与 实 战\n",
      "target： from scratch: photoshop tool detailed and operational\n",
      "pred： from scratch: photoshop tool detailed\n",
      "\n",
      "input： 会 计 转 型 与 进 阶\n",
      "target： accounting transition and progression\n",
      "pred： accounting transition and progression\n",
      "\n",
      "input： 我 的 世 界 高 手 进 阶 指 南 m i n e c r a f t 模 组 m o d 开 发\n",
      "target： my world's best step guide , minecraft model mod development .\n",
      "pred： my world's best step guide , minecraft model mod development .\n",
      "\n",
      "input： 年 岁\n",
      "target： age\n",
      "pred： age age age\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Container class that manages all candidate hypotheses for beam search\n",
    "class BeamHypotheses:\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        # Maximum allowed decoding length for a hypothesis\n",
    "        self.max_length = max_length\n",
    "        # Beam width: maximum number of hypotheses kept at once\n",
    "        self.num_beams = num_beams\n",
    "        # List of (score, token_ids, decoder_hidden) tuples\n",
    "        self.beams = []\n",
    "        # Score of the worst kept hypothesis (sentinel when empty)\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "    \n",
    "    # Add one candidate hypothesis and update the worst score.\n",
    "    # The score is the length-normalized sum of log-probabilities.\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        if len(self) < self.num_beams or score > self.worst_score:\n",
    "            # Insert when capacity is not full or the new score beats\n",
    "            # the current worst\n",
    "            self.beams.append((score, hyp, hidden))\n",
    "            if len(self) > self.num_beams:\n",
    "                # Over capacity: drop the single worst hypothesis and\n",
    "                # promote the next-worst score\n",
    "                sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                    (s, _, _) in enumerate(self.beams)])\n",
    "                del self.beams[sorted_scores[0][1]]\n",
    "                self.worst_score = sorted_scores[1][0]\n",
    "            else:\n",
    "                self.worst_score = min(score, self.worst_score)\n",
    "    \n",
    "    # Remove and return one unfinished hypothesis. The first return\n",
    "    # value indicates success; on success the second value is the\n",
    "    # (score, token_ids, hidden) tuple.\n",
    "    def pop(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        for i, (s, hyp, hid) in enumerate(self.beams):\n",
    "            # An unfinished hypothesis is shorter than the maximum\n",
    "            # decoding length and does not end with <eos>\n",
    "            if len(hyp) < self.max_length and (len(hyp) == 0\\\n",
    "                    or hyp[-1] != EOS_token):\n",
    "                del self.beams[i]\n",
    "                if len(self) > 0:\n",
    "                    # Recompute the worst score over the remaining beams\n",
    "                    sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                        (s, _, _) in enumerate(self.beams)])\n",
    "                    self.worst_score = sorted_scores[0][0]\n",
    "                else:\n",
    "                    self.worst_score = 1e9\n",
    "                return True, (s, hyp, hid)\n",
    "        return False, None\n",
    "    \n",
    "    # Return (without removing) the highest-scoring hypothesis. The\n",
    "    # first return value indicates success; on success the second\n",
    "    # value is the (score, token_ids, hidden) tuple.\n",
    "    def pop_best(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        sorted_scores = sorted([(s, idx) for idx, (s, _, _)\\\n",
    "            in enumerate(self.beams)])\n",
    "        return True, self.beams[sorted_scores[-1][1]]\n",
    "\n",
    "\n",
    "# Decode `sentence` with beam search of width `num_beams`; return the\n",
    "# highest-scoring target sentence as a string ('' if decoding failed)\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # Seed the container with a single empty hypothesis whose\n",
    "        # state is the encoder's final hidden state\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        init_hyp = []\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, init_hyp, encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # Take one unfinished hypothesis per iteration; stop once\n",
    "            # every kept hypothesis is finished\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "                \n",
    "            score, hyp, decoder_hidden = item\n",
    "            \n",
    "            # Current decoder input: the last generated token, or\n",
    "            # <sos> for the empty initial hypothesis\n",
    "            if len(hyp) > 0:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(hyp[-1])\n",
    "            else:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(SOS_token)\n",
    "\n",
    "            # Decode a single step\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "\n",
    "            # Take the top-k tokens from the output distribution\n",
    "            topk_values, topk_ids = decoder_output.topk(num_beams)\n",
    "            # Extend the hypothesis with each candidate token and add\n",
    "            # the results back into the container. view(-1) (instead\n",
    "            # of squeeze()) keeps 1-D tensors even when num_beams == 1,\n",
    "            # so zip() always receives iterables\n",
    "            for logp, token_id in zip(topk_values.view(-1),\\\n",
    "                    topk_ids.view(-1)):\n",
    "                # Undo the length normalization to recover the sum of\n",
    "                # log-probabilities, then add the new token's log-prob\n",
    "                sum_logprobs = score * len(hyp) + logp.item()\n",
    "                new_hyp = hyp + [token_id.item()]\n",
    "                hypotheses.add(sum_logprobs, new_hyp, decoder_hidden)\n",
    "\n",
    "        # All hypotheses finished: return the best one, stripping a\n",
    "        # trailing <eos> token if present\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if flag:\n",
    "            hyp = item[1]\n",
    "            if hyp[-1] == EOS_token:\n",
    "                del hyp[-1]\n",
    "            return output_lang.ids2sent(hyp)\n",
    "        else:\n",
    "            return ''\n",
    "\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence = beam_search_decode(encoder, decoder,\\\n",
    "        pair[0], input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80c8c32-7540-4f51-9af7-c6dcbb306e8c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d993ff05-b61b-4576-ac68-6e3e90e02228",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa9edd9-2768-47a0-8319-338448bff581",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a92b73b-9206-453a-908b-62cd215235f4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
